{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "49100733-1c5e-4ff2-9051-208065a6a86f",
   "metadata": {},
   "source": [
    "### WordPiece"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5f0f8aa8-e3db-42a5-be41-cde4d6373463",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "stats: defaultdict(<class 'int'>, {'我': 1, '喜欢': 2, '吃': 2, '苹果': 1, '他': 1, '不': 1, '苹果派': 1, 'I': 1, 'like': 1, 'to': 1, 'eat': 1, 'apples': 1, 'She': 1, 'has': 1, 'a': 2, 'cute': 2, 'cat': 1, 'you': 2, 'are': 1, 'very': 1, 'give': 1, 'hug': 1})\n",
      "alphabet: ['##a', '##e', '##g', '##h', '##i', '##k', '##l', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##y', '##果', '##欢', '##派', 'I', 'S', 'a', 'c', 'e', 'g', 'h', 'l', 't', 'v', 'y', '不', '他', '吃', '喜', '我', '苹']\n"
     ]
    }
   ],
   "source": [
    "# Toy training corpus: single Chinese words plus short English sentences.\n",
    "# Each entry is later split on whitespace into words by build_stats().\n",
    "sentences = [\n",
    "    \"我\",\n",
    "    \"喜欢\",\n",
    "    \"吃\",\n",
    "    \"苹果\",\n",
    "    \"他\",\n",
    "    \"不\",\n",
    "    \"喜欢\",\n",
    "    \"吃\",\n",
    "    \"苹果派\",\n",
    "    \"I like to eat apples\",\n",
    "    \"She has a cute cat\",\n",
    "    \"you are very cute\",\n",
    "    \"give you a hug\",\n",
    "]\n",
    "\n",
    "from collections import defaultdict\n",
    "def build_stats(sentences):\n",
    "    \"\"\"Count how often each whitespace-separated word occurs in `sentences`.\n",
    "\n",
    "    Returns a defaultdict(int) mapping each word to its number of occurrences.\n",
    "    \"\"\"\n",
    "    word_counts = defaultdict(int)\n",
    "    for line in sentences:\n",
    "        for token in line.split():\n",
    "            word_counts[token] += 1\n",
    "    return word_counts\n",
    "\n",
    "stats = build_stats(sentences)\n",
    "print(\"stats:\", stats)\n",
    "\n",
    "# WordPiece alphabet: the first character of a word is kept as-is, every\n",
    "# following character carries the \"##\" continuation prefix.\n",
    "alphabet = sorted(\n",
    "    {\n",
    "        ch if pos == 0 else f\"##{ch}\"\n",
    "        for word in stats\n",
    "        for pos, ch in enumerate(word)\n",
    "    }\n",
    ")\n",
    "\n",
    "# The initial vocabulary is a copy of the alphabet.\n",
    "vocab = list(alphabet)\n",
    "print(\"alphabet:\", alphabet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d6282ac2-5549-4f55-9aac-bad034c7a0f7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "splits: {'我': ['我'], '喜欢': ['喜', '##欢'], '吃': ['吃'], '苹果': ['苹', '##果'], '他': ['他'], '不': ['不'], '苹果派': ['苹', '##果', '##派'], 'I': ['I'], 'like': ['l', '##i', '##k', '##e'], 'to': ['t', '##o'], 'eat': ['e', '##a', '##t'], 'apples': ['a', '##p', '##p', '##l', '##e', '##s'], 'She': ['S', '##h', '##e'], 'has': ['h', '##a', '##s'], 'a': ['a'], 'cute': ['c', '##u', '##t', '##e'], 'cat': ['c', '##a', '##t'], 'you': ['y', '##o', '##u'], 'are': ['a', '##r', '##e'], 'very': ['v', '##e', '##r', '##y'], 'give': ['g', '##i', '##v', '##e'], 'hug': ['h', '##u', '##g']}\n"
     ]
    }
   ],
   "source": [
    "# Initial segmentation: one token per character; \"##\" marks non-initial ones.\n",
    "splits = {}\n",
    "for word in stats:\n",
    "    splits[word] = list(word[:1]) + [f\"##{ch}\" for ch in word[1:]]\n",
    "print(\"splits:\", splits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1e1bc015-9b09-4172-ac2e-5a744c3e43d4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('喜', '##欢'): 0.5\n",
      "('苹', '##果'): 0.5\n",
      "('##果', '##派'): 0.5\n",
      "('l', '##i'): 0.5\n",
      "('##i', '##k'): 0.5\n",
      "('##k', '##e'): 0.125\n"
     ]
    }
   ],
   "source": [
    "def compute_pair_scores(splits):\n",
    "    \"\"\"Score each adjacent token pair as freq(pair) / (freq(first) * freq(second)).\n",
    "\n",
    "    All frequencies are weighted by the word counts in the global `stats`;\n",
    "    `splits` maps every word to its current token list.\n",
    "    \"\"\"\n",
    "    token_totals = defaultdict(int)\n",
    "    pair_totals = defaultdict(int)\n",
    "    for word, count in stats.items():\n",
    "        tokens = splits[word]\n",
    "        for token in tokens:\n",
    "            token_totals[token] += count\n",
    "        for left, right in zip(tokens, tokens[1:]):\n",
    "            pair_totals[(left, right)] += count\n",
    "\n",
    "    scores = {}\n",
    "    for (left, right), count in pair_totals.items():\n",
    "        scores[(left, right)] = count / (token_totals[left] * token_totals[right])\n",
    "    return scores\n",
    "\n",
    "pair_scores = compute_pair_scores(splits)\n",
    "# Preview the first six scored pairs.\n",
    "for key, score in list(pair_scores.items())[:6]:\n",
    "    print(f\"{key}: {score}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5f0e1163-21ee-4105-bd48-66482bdd59e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('S', '##h') 1.0\n"
     ]
    }
   ],
   "source": [
    "# Select the highest-scoring pair; on ties the first one encountered wins.\n",
    "best_pair, max_score = max(\n",
    "    pair_scores.items(), key=lambda item: item[1], default=(\"\", None)\n",
    ")\n",
    "\n",
    "print(best_pair, max_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4084f9cb-0246-4bc0-b17a-c01bee69c3ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab: ['##a', '##e', '##g', '##h', '##i', '##k', '##l', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##y', '##果', '##欢', '##派', 'I', 'S', 'a', 'c', 'e', 'g', 'h', 'l', 't', 'v', 'y', '不', '他', '吃', '喜', '我', '苹', 'Sh', '喜欢', '苹果', '苹果派', 'li', 'lik', 'gi', 'giv', '##pl', '##ppl', '##ry', 'to', 'yo', 'ea', 'eat']\n"
     ]
    }
   ],
   "source": [
    "def merge_pair(a, b, splits):\n",
    "    \"\"\"Replace every adjacent (a, b) in each word's token list by the merged token.\n",
    "\n",
    "    The \"##\" prefix of `b` is dropped when concatenating. Iterates over the\n",
    "    words of the global `stats`, updates `splits` in place and returns it.\n",
    "    \"\"\"\n",
    "    merged = a + (b[2:] if b.startswith(\"##\") else b)\n",
    "    for word in stats:\n",
    "        tokens = splits[word]\n",
    "        if len(tokens) < 2:\n",
    "            continue\n",
    "        pos = 0\n",
    "        while pos + 1 < len(tokens):\n",
    "            if (tokens[pos], tokens[pos + 1]) == (a, b):\n",
    "                # Do not advance: the merged token is re-examined at `pos`.\n",
    "                tokens = tokens[:pos] + [merged] + tokens[pos + 2 :]\n",
    "            else:\n",
    "                pos += 1\n",
    "        splits[word] = tokens\n",
    "    return splits\n",
    "\n",
    "# Target vocabulary size for the WordPiece demo.\n",
    "vocab_size = 50\n",
    "while len(vocab) < vocab_size:\n",
    "    scores = compute_pair_scores(splits)\n",
    "    if not scores:\n",
    "        # Every word is already a single token, so nothing is left to merge.\n",
    "        # Stop here instead of unpacking an empty best pair into merge_pair().\n",
    "        break\n",
    "    # Highest score wins; max() returns the first maximal pair on ties.\n",
    "    best_pair = max(scores, key=scores.get)\n",
    "    max_score = scores[best_pair]\n",
    "    splits = merge_pair(*best_pair, splits)\n",
    "    left, right = best_pair\n",
    "    # The merged token drops the \"##\" continuation prefix of the right part.\n",
    "    new_token = left + (right[2:] if right.startswith(\"##\") else right)\n",
    "    vocab.append(new_token)\n",
    "\n",
    "print(\"vocab:\", vocab)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cad0484-c620-4074-88ef-a86a64297c83",
   "metadata": {},
   "source": [
    "### Byte-Pair Encoding (BPE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9048c7d1-5f7c-49b6-baff-6abd65f65639",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy corpus for the BPE demo (same entries as in the WordPiece section above).\n",
    "sentences = [\n",
    "    \"我\",\n",
    "    \"喜欢\",\n",
    "    \"吃\",\n",
    "    \"苹果\",\n",
    "    \"他\",\n",
    "    \"不\",\n",
    "    \"喜欢\",\n",
    "    \"吃\",\n",
    "    \"苹果派\",\n",
    "    \"I like to eat apples\",\n",
    "    \"She has a cute cat\",\n",
    "    \"you are very cute\",\n",
    "    \"give you a hug\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3ecdc5f7-e981-4f4d-9ef9-0aa7a2e807c2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "stats: defaultdict(<class 'int'>, {'我': 1, '喜欢': 2, '吃': 2, '苹果': 1, '他': 1, '不': 1, '苹果派': 1, 'I': 1, 'like': 1, 'to': 1, 'eat': 1, 'apples': 1, 'She': 1, 'has': 1, 'a': 2, 'cute': 2, 'cat': 1, 'you': 2, 'are': 1, 'very': 1, 'give': 1, 'hug': 1})\n",
      "alphabet: ['I', 'S', 'a', 'c', 'e', 'g', 'h', 'i', 'k', 'l', 'o', 'p', 'r', 's', 't', 'u', 'v', 'y', '不', '他', '吃', '喜', '我', '果', '欢', '派', '苹']\n"
     ]
    }
   ],
   "source": [
    "def build_stats(sentences):\n",
    "    \"\"\"Return a defaultdict(int) mapping each word to its occurrence count.\n",
    "\n",
    "    Words are obtained by splitting every sentence on whitespace.\n",
    "    \"\"\"\n",
    "    counts = defaultdict(int)\n",
    "    for sentence in sentences:\n",
    "        for word in sentence.split():\n",
    "            counts[word] += 1\n",
    "    return counts\n",
    "\n",
    "stats = build_stats(sentences)\n",
    "print(\"stats:\", stats)\n",
    "\n",
    "# BPE alphabet: every distinct character seen in the corpus, sorted.\n",
    "alphabet = sorted({letter for word in stats for letter in word})\n",
    "\n",
    "# The initial vocabulary is a copy of the alphabet.\n",
    "vocab = list(alphabet)\n",
    "print(\"alphabet:\", alphabet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3e438e15-dce4-479b-9447-083ca242c716",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "splits: {'我': ['我'], '喜欢': ['喜', '欢'], '吃': ['吃'], '苹果': ['苹', '果'], '他': ['他'], '不': ['不'], '苹果派': ['苹', '果', '派'], 'I': ['I'], 'like': ['l', 'i', 'k', 'e'], 'to': ['t', 'o'], 'eat': ['e', 'a', 't'], 'apples': ['a', 'p', 'p', 'l', 'e', 's'], 'She': ['S', 'h', 'e'], 'has': ['h', 'a', 's'], 'a': ['a'], 'cute': ['c', 'u', 't', 'e'], 'cat': ['c', 'a', 't'], 'you': ['y', 'o', 'u'], 'are': ['a', 'r', 'e'], 'very': ['v', 'e', 'r', 'y'], 'give': ['g', 'i', 'v', 'e'], 'hug': ['h', 'u', 'g']}\n",
      "('喜', '欢'): 2\n",
      "('苹', '果'): 2\n",
      "('果', '派'): 1\n",
      "('l', 'i'): 1\n",
      "('i', 'k'): 1\n",
      "('k', 'e'): 1\n"
     ]
    }
   ],
   "source": [
    "# Start with every word split into its individual characters.\n",
    "splits = {word: list(word) for word in stats}\n",
    "print(\"splits:\", splits)\n",
    "\n",
    "def compute_pair_freqs(splits):\n",
    "    \"\"\"Count adjacent token pairs over all words, weighted by word frequency.\n",
    "\n",
    "    Word frequencies come from the global `stats`; `splits` maps each word\n",
    "    to its current token list. Returns a defaultdict(int) keyed by pair.\n",
    "    \"\"\"\n",
    "    totals = defaultdict(int)\n",
    "    for word, count in stats.items():\n",
    "        tokens = splits[word]\n",
    "        for pair in zip(tokens, tokens[1:]):\n",
    "            totals[pair] += count\n",
    "    return totals\n",
    "pair_freqs = compute_pair_freqs(splits)\n",
    "\n",
    "# Preview the first six pair counts.\n",
    "for key, freq in list(pair_freqs.items())[:6]:\n",
    "    print(f\"{key}: {freq}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "14b12e7e-8910-40dc-8217-1f46d603d33f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('喜', '欢') 2\n"
     ]
    }
   ],
   "source": [
    "# Select the most frequent pair; on ties the first one encountered wins.\n",
    "best_pair, max_freq = max(\n",
    "    pair_freqs.items(), key=lambda item: item[1], default=(\"\", None)\n",
    ")\n",
    "\n",
    "print(best_pair, max_freq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "68c3d67a-fb1b-4573-a629-29cb7c09e5a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "merges: {('喜', '欢'): '喜欢', ('苹', '果'): '苹果', ('a', 't'): 'at', ('c', 'u'): 'cu', ('cu', 't'): 'cut', ('cut', 'e'): 'cute', ('y', 'o'): 'yo', ('yo', 'u'): 'you', ('v', 'e'): 've', ('苹果', '派'): '苹果派', ('l', 'i'): 'li', ('li', 'k'): 'lik', ('lik', 'e'): 'like', ('t', 'o'): 'to', ('e', 'at'): 'eat', ('a', 'p'): 'ap', ('ap', 'p'): 'app', ('app', 'l'): 'appl', ('appl', 'e'): 'apple', ('apple', 's'): 'apples', ('S', 'h'): 'Sh', ('Sh', 'e'): 'She', ('h', 'a'): 'ha'}\n",
      "vocab: ['I', 'S', 'a', 'c', 'e', 'g', 'h', 'i', 'k', 'l', 'o', 'p', 'r', 's', 't', 'u', 'v', 'y', '不', '他', '吃', '喜', '我', '果', '欢', '派', '苹', '喜欢', '苹果', 'at', 'cu', 'cut', 'cute', 'yo', 'you', 've', '苹果派', 'li', 'lik', 'like', 'to', 'eat', 'ap', 'app', 'appl', 'apple', 'apples', 'Sh', 'She', 'ha']\n"
     ]
    }
   ],
   "source": [
    "def merge_pair(a, b, splits):\n",
    "    \"\"\"Fuse every adjacent (a, b) in each word's token list into `a + b`.\n",
    "\n",
    "    Iterates over the words of the global `stats`, updates `splits` in\n",
    "    place and returns it.\n",
    "    \"\"\"\n",
    "    for word in stats:\n",
    "        tokens = splits[word]\n",
    "        if len(tokens) < 2:\n",
    "            continue\n",
    "        pos = 0\n",
    "        while pos + 1 < len(tokens):\n",
    "            if (tokens[pos], tokens[pos + 1]) == (a, b):\n",
    "                # Do not advance: the merged token is re-examined at `pos`.\n",
    "                tokens = tokens[:pos] + [a + b] + tokens[pos + 2 :]\n",
    "            else:\n",
    "                pos += 1\n",
    "        splits[word] = tokens\n",
    "    return splits\n",
    "\n",
    "# Assume the target vocabulary size is 50.\n",
    "merges = {}\n",
    "vocab_size = 50\n",
    "\n",
    "while len(vocab) < vocab_size:\n",
    "    pair_freqs = compute_pair_freqs(splits)\n",
    "    if not pair_freqs:\n",
    "        # Every word is already a single token, so nothing is left to merge.\n",
    "        # Stop here instead of unpacking an empty best pair into merge_pair().\n",
    "        break\n",
    "    # Most frequent pair wins; max() returns the first maximal pair on ties.\n",
    "    best_pair = max(pair_freqs, key=pair_freqs.get)\n",
    "    max_freq = pair_freqs[best_pair]\n",
    "    splits = merge_pair(*best_pair, splits)\n",
    "    merged_token = best_pair[0] + best_pair[1]\n",
    "    # Record the merge rule and add the new token to the vocabulary.\n",
    "    merges[best_pair] = merged_token\n",
    "    vocab.append(merged_token)\n",
    "\n",
    "print(\"merges:\", merges)\n",
    "print(\"vocab:\", vocab)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4589d599-0c4f-467c-8c03-bec879f93982",
   "metadata": {},
   "source": [
    "### Byte-level BPE (BBPE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "14cdfdb4-a3c6-414e-9a02-34e3e04398b0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "initial_vocab: [b'\\x00', b'\\x01', b'\\x02', b'\\x03', b'\\x04', b'\\x05', b'\\x06', b'\\x07', b'\\x08', b'\\t', b'\\n', b'\\x0b', b'\\x0c', b'\\r', b'\\x0e', b'\\x0f', b'\\x10', b'\\x11', b'\\x12', b'\\x13', b'\\x14', b'\\x15', b'\\x16', b'\\x17', b'\\x18', b'\\x19', b'\\x1a', b'\\x1b', b'\\x1c', b'\\x1d', b'\\x1e', b'\\x1f', b' ', b'!', b'\"', b'#', b'$', b'%', b'&', b\"'\", b'(', b')', b'*', b'+', b',', b'-', b'.', b'/', b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', b':', b';', b'<', b'=', b'>', b'?', b'@', b'A', b'B', b'C', b'D', b'E', b'F', b'G', b'H', b'I', b'J', b'K', b'L', b'M', b'N', b'O', b'P', b'Q', b'R', b'S', b'T', b'U', b'V', b'W', b'X', b'Y', b'Z', b'[', b'\\\\', b']', b'^', b'_', b'`', b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', b'z', b'{', b'|', b'}', b'~', b'\\x7f', b'\\x80', b'\\x81', b'\\x82', b'\\x83', b'\\x84', b'\\x85', b'\\x86', b'\\x87', b'\\x88', b'\\x89', b'\\x8a', b'\\x8b', b'\\x8c', b'\\x8d', b'\\x8e', b'\\x8f', b'\\x90', b'\\x91', b'\\x92', b'\\x93', b'\\x94', b'\\x95', b'\\x96', b'\\x97', b'\\x98', b'\\x99', b'\\x9a', b'\\x9b', b'\\x9c', b'\\x9d', b'\\x9e', b'\\x9f', b'\\xa0', b'\\xa1', b'\\xa2', b'\\xa3', b'\\xa4', b'\\xa5', b'\\xa6', b'\\xa7', b'\\xa8', b'\\xa9', b'\\xaa', b'\\xab', b'\\xac', b'\\xad', b'\\xae', b'\\xaf', b'\\xb0', b'\\xb1', b'\\xb2', b'\\xb3', b'\\xb4', b'\\xb5', b'\\xb6', b'\\xb7', b'\\xb8', b'\\xb9', b'\\xba', b'\\xbb', b'\\xbc', b'\\xbd', b'\\xbe', b'\\xbf', b'\\xc0', b'\\xc1', b'\\xc2', b'\\xc3', b'\\xc4', b'\\xc5', b'\\xc6', b'\\xc7', b'\\xc8', b'\\xc9', b'\\xca', b'\\xcb', b'\\xcc', b'\\xcd', b'\\xce', b'\\xcf', b'\\xd0', b'\\xd1', b'\\xd2', b'\\xd3', b'\\xd4', b'\\xd5', b'\\xd6', b'\\xd7', b'\\xd8', b'\\xd9', b'\\xda', b'\\xdb', b'\\xdc', b'\\xdd', b'\\xde', b'\\xdf', b'\\xe0', b'\\xe1', b'\\xe2', b'\\xe3', b'\\xe4', b'\\xe5', b'\\xe6', b'\\xe7', b'\\xe8', b'\\xe9', b'\\xea', 
b'\\xeb', b'\\xec', b'\\xed', b'\\xee', b'\\xef', b'\\xf0', b'\\xf1', b'\\xf2', b'\\xf3', b'\\xf4', b'\\xf5', b'\\xf6', b'\\xf7', b'\\xf8', b'\\xf9', b'\\xfa', b'\\xfb', b'\\xfc', b'\\xfd', b'\\xfe', b'\\xff']\n",
      "vocab: [b'\\x00', b'\\x01', b'\\x02', b'\\x03', b'\\x04', b'\\x05', b'\\x06', b'\\x07', b'\\x08', b'\\t', b'\\n', b'\\x0b', b'\\x0c', b'\\r', b'\\x0e', b'\\x0f', b'\\x10', b'\\x11', b'\\x12', b'\\x13', b'\\x14', b'\\x15', b'\\x16', b'\\x17', b'\\x18', b'\\x19', b'\\x1a', b'\\x1b', b'\\x1c', b'\\x1d', b'\\x1e', b'\\x1f', b' ', b'!', b'\"', b'#', b'$', b'%', b'&', b\"'\", b'(', b')', b'*', b'+', b',', b'-', b'.', b'/', b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', b':', b';', b'<', b'=', b'>', b'?', b'@', b'A', b'B', b'C', b'D', b'E', b'F', b'G', b'H', b'I', b'J', b'K', b'L', b'M', b'N', b'O', b'P', b'Q', b'R', b'S', b'T', b'U', b'V', b'W', b'X', b'Y', b'Z', b'[', b'\\\\', b']', b'^', b'_', b'`', b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', b'z', b'{', b'|', b'}', b'~', b'\\x7f', b'\\x80', b'\\x81', b'\\x82', b'\\x83', b'\\x84', b'\\x85', b'\\x86', b'\\x87', b'\\x88', b'\\x89', b'\\x8a', b'\\x8b', b'\\x8c', b'\\x8d', b'\\x8e', b'\\x8f', b'\\x90', b'\\x91', b'\\x92', b'\\x93', b'\\x94', b'\\x95', b'\\x96', b'\\x97', b'\\x98', b'\\x99', b'\\x9a', b'\\x9b', b'\\x9c', b'\\x9d', b'\\x9e', b'\\x9f', b'\\xa0', b'\\xa1', b'\\xa2', b'\\xa3', b'\\xa4', b'\\xa5', b'\\xa6', b'\\xa7', b'\\xa8', b'\\xa9', b'\\xaa', b'\\xab', b'\\xac', b'\\xad', b'\\xae', b'\\xaf', b'\\xb0', b'\\xb1', b'\\xb2', b'\\xb3', b'\\xb4', b'\\xb5', b'\\xb6', b'\\xb7', b'\\xb8', b'\\xb9', b'\\xba', b'\\xbb', b'\\xbc', b'\\xbd', b'\\xbe', b'\\xbf', b'\\xc0', b'\\xc1', b'\\xc2', b'\\xc3', b'\\xc4', b'\\xc5', b'\\xc6', b'\\xc7', b'\\xc8', b'\\xc9', b'\\xca', b'\\xcb', b'\\xcc', b'\\xcd', b'\\xce', b'\\xcf', b'\\xd0', b'\\xd1', b'\\xd2', b'\\xd3', b'\\xd4', b'\\xd5', b'\\xd6', b'\\xd7', b'\\xd8', b'\\xd9', b'\\xda', b'\\xdb', b'\\xdc', b'\\xdd', b'\\xde', b'\\xdf', b'\\xe0', b'\\xe1', b'\\xe2', b'\\xe3', b'\\xe4', b'\\xe5', b'\\xe6', b'\\xe7', b'\\xe8', b'\\xe9', b'\\xea', b'\\xeb', 
b'\\xec', b'\\xed', b'\\xee', b'\\xef', b'\\xf0', b'\\xf1', b'\\xf2', b'\\xf3', b'\\xf4', b'\\xf5', b'\\xf6', b'\\xf7', b'\\xf8', b'\\xf9', b'\\xfa', b'\\xfb', b'\\xfc', b'\\xfd', b'\\xfe', b'\\xff']\n"
     ]
    }
   ],
   "source": [
    "from collections import defaultdict\n",
    "# Toy corpus for the byte-level BPE demo (same entries as in the sections above).\n",
    "sentences = [\n",
    "    \"我\",\n",
    "    \"喜欢\",\n",
    "    \"吃\",\n",
    "    \"苹果\",\n",
    "    \"他\",\n",
    "    \"不\",\n",
    "    \"喜欢\",\n",
    "    \"吃\",\n",
    "    \"苹果派\",\n",
    "    \"I like to eat apples\",\n",
    "    \"She has a cute cat\",\n",
    "    \"you are very cute\",\n",
    "    \"give you a hug\",\n",
    "]\n",
    "# Base vocabulary: one single-byte token for each of the 256 byte values.\n",
    "initial_vocab = [bytes((value,)) for value in range(256)]\n",
    "vocab = list(initial_vocab)\n",
    "print(\"initial_vocab:\", initial_vocab)\n",
    "\n",
    "def build_stats(sentences):\n",
    "    \"\"\"Count whitespace-separated words, keyed by their UTF-8 encoded bytes.\n",
    "\n",
    "    Returns a defaultdict(int) mapping bytes -> number of occurrences.\n",
    "    \"\"\"\n",
    "    counts = defaultdict(int)\n",
    "    for sentence in sentences:\n",
    "        for word in sentence.split():\n",
    "            counts[word.encode(\"utf-8\")] += 1\n",
    "    return counts\n",
    "stats = build_stats(sentences)\n",
    "\n",
    "# Iterating a bytes object yields ints, so each word starts as a list of byte values.\n",
    "splits = {word: list(word) for word in stats}\n",
    "def compute_pair_freqs(splits):\n",
    "    \"\"\"Count adjacent token pairs over all words, weighted by word frequency.\n",
    "\n",
    "    Word frequencies come from the global `stats`; `splits` maps each word\n",
    "    to its current token list. Returns a defaultdict(int) keyed by pair.\n",
    "    \"\"\"\n",
    "    totals = defaultdict(int)\n",
    "    for word, count in stats.items():\n",
    "        tokens = splits[word]\n",
    "        for pair in zip(tokens, tokens[1:]):\n",
    "            totals[pair] += count\n",
    "    return totals\n",
    "\n",
    "pair_freqs = compute_pair_freqs(splits)\n",
    "\n",
    "def _as_bytes(token):\n",
    "    \"\"\"Return `token` as bytes; an int byte value becomes a 1-byte bytes object.\"\"\"\n",
    "    return bytes([token]) if isinstance(token, int) else token\n",
    "\n",
    "def merge_pair(pair, splits):\n",
    "    \"\"\"Replace every adjacent occurrence of `pair` by its concatenated bytes.\n",
    "\n",
    "    Iterates over the words of the global `stats`, updates `splits` in place\n",
    "    and returns it. Tokens may be ints (raw byte values from the initial\n",
    "    split) or bytes objects (results of earlier merges).\n",
    "    \"\"\"\n",
    "    first, second = pair\n",
    "    # bytes(pair) only works for a pair of ints; concatenating also handles\n",
    "    # pairs that contain already-merged bytes tokens.\n",
    "    merged_byte = _as_bytes(first) + _as_bytes(second)\n",
    "    for word in stats:\n",
    "        split = splits[word]\n",
    "        if len(split) < 2:\n",
    "            continue\n",
    "        i = 0\n",
    "        while i < len(split) - 1:\n",
    "            # Compare element-wise: a list slice never equals the pair tuple.\n",
    "            if split[i] == first and split[i + 1] == second:\n",
    "                split = split[:i] + [merged_byte] + split[i + 2 :]\n",
    "            else:\n",
    "                i += 1\n",
    "        splits[word] = split\n",
    "    return splits\n",
    "\n",
    "vocab_size = 50\n",
    "# NOTE: vocab starts with all 256 single-byte tokens, so with vocab_size = 50\n",
    "# the condition below is false from the start and no merge is learned.\n",
    "# Set vocab_size above 256 (256 + number of merges wanted) to train merges.\n",
    "while len(vocab) < vocab_size:\n",
    "    pair_freqs = compute_pair_freqs(splits)\n",
    "    if not pair_freqs:\n",
    "        # Every word is already a single token, so nothing is left to merge.\n",
    "        break\n",
    "    # Most frequent pair wins; max() returns the first maximal pair on ties.\n",
    "    best_pair = max(pair_freqs, key=pair_freqs.get)\n",
    "    max_freq = pair_freqs[best_pair]\n",
    "    splits = merge_pair(best_pair, splits)\n",
    "    merged_byte = _as_bytes(best_pair[0]) + _as_bytes(best_pair[1])\n",
    "    # Grow the vocabulary; without this the loop could never terminate.\n",
    "    vocab.append(merged_byte)\n",
    "\n",
    "print(\"vocab:\", vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71a225fd-f5cf-4e11-a178-30d5f96c7cc8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
